﻿using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace VeloxDev.Core.Generator
{
    [Generator(LanguageNames.CSharp)]
    public sealed class AOTReflection : IIncrementalGenerator
    {
        /// <summary>
        /// Builds the incremental pipeline: finds class declarations annotated with
        /// <c>VeloxDev.Core.AOT.AOTReflectionAttribute</c> and emits one
        /// <c>AOTReflection</c> initializer source file that covers all of them.
        /// </summary>
        /// <param name="context">Initialization context used to register the pipeline steps.</param>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            // Step 1: collect the classes in this project that carry [AOTReflection].
            var localTypes = context.SyntaxProvider
                .CreateSyntaxProvider(
                    // Syntax-only pre-filter: class declarations with at least one attribute list.
                    static (s, _) => s is ClassDeclarationSyntax cds && cds.AttributeLists.Count > 0,
                    static (ctx, ct) =>
                    {
                        var cds = (ClassDeclarationSyntax)ctx.Node;
                        foreach (var list in cds.AttributeLists)
                        {
                            foreach (var attr in list.Attributes)
                            {
                                // The attribute syntax binds to a member of the attribute class,
                                // so the containing type of that symbol is the attribute class itself.
                                var symbol = ctx.SemanticModel.GetSymbolInfo(attr, ct).Symbol?.ContainingType;
                                if (symbol is null)
                                {
                                    continue;
                                }

                                if (symbol.ToDisplayString() == "VeloxDev.Core.AOT.AOTReflectionAttribute")
                                {
                                    var typeSymbol = ctx.SemanticModel.GetDeclaredSymbol(cds, ct) as INamedTypeSymbol;
                                    if (typeSymbol is not null)
                                    {
                                        return typeSymbol;
                                    }
                                }
                            }
                        }

                        return null;
                    })
                .Where(static t => t is not null)!;

            // Step 2: register the output. The compilation itself is not needed to produce the
            // file, so the collected symbols are consumed directly instead of being combined
            // with CompilationProvider.
            context.RegisterSourceOutput(localTypes.Collect(), static (spc, localTypeSymbols) =>
            {
                // Only types declared in the current project are considered (referenced
                // assemblies are not scanned). The set also collapses partial declarations
                // of the same class into a single entry.
                var allMarkedTypes = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
                foreach (var t in localTypeSymbols)
                {
                    if (t is not null)
                    {
                        allMarkedTypes.Add(t);
                    }
                }

                // Nothing is marked: do not generate a file.
                if (allMarkedTypes.Count == 0)
                {
                    return;
                }

                // Resolve the target namespace (user-specified or inferred).
                var ns = InferNamespace(allMarkedTypes);

                // Generate the source file.
                var source = GenerateAOTReflectionSource(allMarkedTypes, ns);
                spc.AddSource($"{ns}.AOTReflection.g.cs", SourceText.From(source, Encoding.UTF8));
            });
        }

        /// <summary>
        /// Chooses the namespace that the generated <c>AOTReflection</c> class is placed in.
        /// </summary>
        /// <remarks>
        /// If any marked type passes a non-blank string other than <c>"Auto"</c> as the first
        /// constructor argument of its <c>AOTReflectionAttribute</c>, the first such string found
        /// is returned. Otherwise the longest common dotted prefix of the marked types' namespaces
        /// is used, falling back to <c>"GlobalAOT"</c> when there is no usable prefix.
        /// </remarks>
        /// <param name="types">The marked types; enumerated twice.</param>
        /// <returns>The namespace name to emit.</returns>
        private static string InferNamespace(IEnumerable<INamedTypeSymbol> types)
        {
            // An explicit namespace supplied through the attribute constructor wins.
            foreach (var type in types)
            {
                var attr = type.GetAttributes().FirstOrDefault(a =>
                    a.AttributeClass?.ToDisplayString() == "VeloxDev.Core.AOT.AOTReflectionAttribute");
                if (attr is not null && attr.ConstructorArguments.Length > 0)
                {
                    var firstArg = attr.ConstructorArguments[0];
                    if (firstArg.Kind == TypedConstantKind.Primitive && firstArg.Value is string s && !string.IsNullOrWhiteSpace(s) && s != "Auto")
                    {
                        return s;
                    }
                }
            }

            // Automatic inference: longest common prefix of the containing namespaces.
            // The global namespace must be skipped via IsGlobalNamespace: its display string is
            // "<global namespace>", not an empty string, and would otherwise end up in the
            // emitted namespace declaration and in the generated file's hint name.
            var namespaces = new List<string>();
            foreach (var type in types)
            {
                if (type.ContainingNamespace is { IsGlobalNamespace: false } containing)
                {
                    var name = containing.ToDisplayString();
                    if (!string.IsNullOrEmpty(name))
                    {
                        namespaces.Add(name);
                    }
                }
            }

            return GetCommonPrefixNamespace(namespaces) ?? "GlobalAOT";
        }

        /// <summary>
        /// Computes the longest run of leading dot-separated segments shared by every
        /// namespace in the list. Comparison is per whole segment, not per character.
        /// </summary>
        /// <param name="namespaces">Dotted namespace names to compare.</param>
        /// <returns>
        /// The shared leading segments joined with <c>'.'</c>, or <c>null</c> when the list is
        /// empty or the namespaces share no leading segment.
        /// </returns>
        private static string? GetCommonPrefixNamespace(List<string> namespaces)
        {
            if (namespaces.Count == 0)
            {
                return null;
            }

            List<string[]> segmented = namespaces.ConvertAll(static name => name.Split('.'));
            string[] reference = segmented[0];

            // Number of leading segments of the first namespace still shared by everything seen so far.
            int shared = reference.Length;

            foreach (string[] parts in segmented)
            {
                int limit = parts.Length < shared ? parts.Length : shared;
                int matched = 0;
                while (matched < limit && parts[matched] == reference[matched])
                {
                    matched++;
                }

                shared = matched;
            }

            return shared > 0 ? string.Join(".", reference, 0, shared) : null;
        }

        /// <summary>
        /// Produces the text of the generated <c>AOTReflection</c> static class, whose
        /// <c>Init()</c> method references each marked type through <c>typeof(...)</c> and,
        /// depending on the attribute options, its constructors, methods, properties and fields.
        /// </summary>
        /// <param name="types">The marked types; duplicates are ignored.</param>
        /// <param name="ns">Namespace the generated class is declared in; written verbatim.</param>
        /// <returns>The complete C# source text.</returns>
        private static string GenerateAOTReflectionSource(IEnumerable<INamedTypeSymbol> types, string ns)
        {
            const string AttributeName = "VeloxDev.Core.AOT.AOTReflectionAttribute";
            const string AllMembers = "BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static";

            var sb = new StringBuilder();
            sb.AppendLine("// <auto-generated />");
            sb.AppendLine("using System;");
            sb.AppendLine("using System.Reflection;");
            sb.AppendLine();
            sb.AppendLine($"namespace {ns}");
            sb.AppendLine("{");
            sb.AppendLine("    /// <summary>");
            sb.AppendLine("    /// 自动生成的 AOT 反射初始化类");
            sb.AppendLine("    /// </summary>");
            sb.AppendLine("    public static class AOTReflection");
            sb.AppendLine("    {");
            sb.AppendLine("        public static void Init()");
            sb.AppendLine("        {");

            foreach (var type in types.Distinct<INamedTypeSymbol>(SymbolEqualityComparer.Default))
            {
                // The generated class lives outside the marked type, so a generic type's own
                // type parameters cannot be named there: typeof(Foo<T>) would not compile.
                // Use the unbound form (typeof(Foo<>)) for generic types instead.
                var typeofTarget = type.IsGenericType && !type.IsUnboundGenericType
                    ? type.ConstructUnboundGenericType()
                    : type;
                var fullName = typeofTarget.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

                sb.AppendLine($"            // [{type.ContainingNamespace?.ToDisplayString()}] {type.Name}");
                sb.AppendLine($"            _ = typeof({fullName}).GetTypeInfo();");

                var attr = type.GetAttributes().FirstOrDefault(a =>
                    a.AttributeClass?.ToDisplayString() == AttributeName);

                bool includeCtors = false, includeMethods = false, includeProps = false, includeFields = false;

                if (attr is not null)
                {
                    // Positional constructor arguments: index 0 is the namespace (handled by
                    // InferNamespace), indices 1-4 are the member-kind switches.
                    includeCtors = GetCtorArg(attr, 1);
                    includeMethods = GetCtorArg(attr, 2);
                    includeProps = GetCtorArg(attr, 3);
                    includeFields = GetCtorArg(attr, 4);

                    // Named arguments override the positional values.
                    foreach (var kv in attr.NamedArguments)
                    {
                        // Skip anything that is not a plain bool constant (e.g. an argument that
                        // failed to bind while the user is typing) rather than throwing from
                        // inside the generator. TypedConstant.Value must not be read on arrays.
                        if (kv.Value.Kind == TypedConstantKind.Array || kv.Value.Value is not bool flag)
                        {
                            continue;
                        }

                        switch (kv.Key)
                        {
                            case "IncludeConstructors": includeCtors = flag; break;
                            case "IncludeMethods": includeMethods = flag; break;
                            case "IncludeProperties": includeProps = flag; break;
                            case "IncludeFields": includeFields = flag; break;
                        }
                    }
                }

                // Emit the reflection calls selected by the options.
                if (includeCtors)
                {
                    sb.AppendLine($"            _ = typeof({fullName}).GetConstructors({AllMembers});");
                }

                if (includeMethods)
                {
                    sb.AppendLine($"            _ = typeof({fullName}).GetMethods({AllMembers});");
                }

                if (includeProps)
                {
                    sb.AppendLine($"            _ = typeof({fullName}).GetProperties({AllMembers});");
                }

                if (includeFields)
                {
                    sb.AppendLine($"            _ = typeof({fullName}).GetFields({AllMembers});");
                }

                sb.AppendLine();
            }

            sb.AppendLine("        }");
            sb.AppendLine("    }");
            sb.AppendLine("}");
            return sb.ToString();
        }

        /// <summary>
        /// Reads the positional constructor argument at <paramref name="index"/> as a boolean switch.
        /// </summary>
        /// <param name="attr">Attribute data to read from; may be <c>null</c>.</param>
        /// <param name="index">Zero-based position of the constructor argument.</param>
        /// <returns>
        /// <c>true</c> only when the argument exists and its value is the boolean <c>true</c>;
        /// <c>false</c> when <paramref name="attr"/> is <c>null</c>, the attribute has too few
        /// arguments, or the value is anything else.
        /// </returns>
        private static bool GetCtorArg(AttributeData? attr, int index)
        {
            if (attr is null)
            {
                return false;
            }

            var positional = attr.ConstructorArguments;
            return index < positional.Length && positional[index].Value is true;
        }
    }
}
